"""
火山引擎TTS客户端 - 基于官方协议完全重构
"""
import asyncio
import copy
import json
import logging
import uuid
import io
from typing import Optional, AsyncGenerator, Dict, Any

import websockets
import numpy as np

# 音频处理库
try:
    from pydub import AudioSegment
    AUDIO_PROCESSING_AVAILABLE = True
except ImportError:
    AUDIO_PROCESSING_AVAILABLE = False

from .protocols import (
    EventType,
    MsgType,
    finish_connection,
    finish_session,
    receive_message,
    start_connection,
    start_session,
    task_request,
    wait_for_event,
)
from .errors import (
    TTSError,
    TTSConnectionError,
    TTSAuthenticationError,
    TTSQuotaExceededError,
    TTSRateLimitError,
    TTSServerError,
    TTSSessionError,
    TTSErrorCode,
    create_tts_exception,
    get_quota_solution,
)
from config.config_manager import TTSConfig

# Logger for this module, named after ``__name__``.
logger = logging.getLogger(__name__)


class TTSClient:
    """Volcengine TTS client built on the official bidirectional WebSocket protocol.

    A single WebSocket connection is kept in ``self.websocket`` and reused across
    synthesis sessions. The ``synthesize_*`` methods read frames from that shared
    socket without any locking, so one instance must not run two of them
    concurrently (their frames would interleave).
    """

    # Seconds close() waits for the server's ConnectionFinished reply before it
    # closes the socket anyway.
    _CLOSE_TIMEOUT = 5.0

    def __init__(self, config: TTSConfig):
        self.config = config
        self.websocket = None
        # Serializes connect() so concurrent callers cannot open duplicate sockets.
        self._connection_lock = asyncio.Lock()

    def _get_resource_id(self, voice: str) -> str:
        """Return the resource ID for a voice name.

        Voices starting with ``S_`` map to ``volc.megatts.default``, all others to
        ``volc.service_type.10029``. Note: connect() sends ``config.resource_id``
        and does not call this helper.
        """
        if voice.startswith("S_"):
            return "volc.megatts.default"
        return "volc.service_type.10029"

    def _is_open(self) -> bool:
        """Return True if ``self.websocket`` exists and is in the OPEN state.

        ``state`` is checked first: the current ``websockets`` connection class
        (the one that accepts ``additional_headers``) has no ``closed``
        attribute, so treating a missing ``closed`` as "closed" would make every
        connection look dead. ``closed`` is only used as a fallback.
        """
        websocket = self.websocket
        if websocket is None:
            return False
        state = getattr(websocket, "state", None)
        if state is not None:
            return getattr(state, "name", None) == "OPEN"
        return not getattr(websocket, "closed", True)

    @staticmethod
    async def _close_socket(websocket) -> None:
        """Close ``websocket`` without the FinishConnection handshake; log instead of raising."""
        try:
            await websocket.close()
        except Exception as e:
            logger.warning(f"关闭连接时出错: {e}")

    async def connect(self, force_reconnect: bool = False) -> None:
        """Open the WebSocket and complete the StartConnection handshake.

        An already-open connection is reused unless ``force_reconnect`` is True.
        On handshake failure the new socket is closed and ``self.websocket`` is
        left unset, so a broken connection is never reused.

        Raises:
            TTSAuthenticationError: the handshake error mentions "ConnectionFailed".
            TTSConnectionError: any other failure while waiting for ConnectionStarted.
        """
        async with self._connection_lock:
            if self._is_open() and not force_reconnect:
                return

            # Tear down any existing (stale or to-be-replaced) connection first.
            if self.websocket:
                await self.close()

            headers = {
                "X-Api-App-Key": self.config.app_id,
                "X-Api-Access-Key": self.config.access_token,
                "X-Api-Resource-Id": self.config.resource_id,
                "X-Api-Connect-Id": str(uuid.uuid4()),
            }

            logger.info(f"连接到 {self.config.ws_url}")

            websocket = await websockets.connect(
                self.config.ws_url,
                additional_headers=headers,
                max_size=10 * 1024 * 1024,
                ping_interval=None,  # client-side pings disabled
            )

            try:
                logger.info(f"WebSocket连接建立成功, Logid: {websocket.response.headers.get('x-tt-logid', 'N/A')}")

                await start_connection(websocket)
                try:
                    await wait_for_event(websocket, MsgType.FullServerResponse, EventType.ConnectionStarted)
                except Exception as e:
                    # The protocol layer reports rejected credentials as ConnectionFailed.
                    if "ConnectionFailed" in str(e):
                        raise TTSAuthenticationError(
                            "TTS连接认证失败，请检查APP_ID和ACCESS_TOKEN",
                            error_code=TTSErrorCode.CLIENT_AUTH_FAILED
                        ) from e
                    raise TTSConnectionError(f"TTS连接初始化失败: {e}") from e
            except BaseException:
                # Never leak or keep a half-initialized socket.
                await self._close_socket(websocket)
                raise

            self.websocket = websocket
            logger.info("TTS连接初始化完成")

    async def close(self) -> None:
        """Close the connection, sending FinishConnection first if it is still open.

        Never raises on handshake problems (they are logged). The socket is always
        closed and ``self.websocket`` reset to None, even when the handshake fails
        or times out.
        """
        websocket = self.websocket
        if not websocket:
            return
        try:
            if self._is_open():
                await finish_connection(websocket)
                await asyncio.wait_for(
                    wait_for_event(websocket, MsgType.FullServerResponse, EventType.ConnectionFinished),
                    timeout=self._CLOSE_TIMEOUT,
                )
        except Exception as e:
            logger.warning(f"关闭连接时出错: {e}")
        finally:
            self.websocket = None
            await self._close_socket(websocket)
        logger.info("TTS连接已关闭")

    def _create_base_request(self, speaker: str) -> Dict[str, Any]:
        """Build the request fields shared by StartSession and TaskRequest.

        Args:
            speaker: Voice name placed in ``req_params.speaker``.

        Returns:
            A dict with ``user``, ``namespace`` and ``req_params``; the
            ``additions`` entry is a JSON-encoded string.
        """
        audio_params = {
            "format": self.config.format,
            "sample_rate": self.config.sample_rate,
            "enable_timestamp": self.config.enable_timestamp,
        }

        # 1.0 is treated as the default, so it is not sent explicitly.
        if self.config.speech_rate != 1.0:
            audio_params["speed_ratio"] = self.config.speech_rate

        if self.config.loudness_rate != 1.0:
            audio_params["volume_ratio"] = self.config.loudness_rate

        if self.config.emotion:
            audio_params["emotion"] = self.config.emotion

        if self.config.emotion_scale:
            audio_params["emotion_scale"] = self.config.emotion_scale

        additions = {
            "disable_markdown_filter": self.config.disable_markdown_filter,
            "disable_emoji_filter": self.config.disable_emoji_filter,
            "enable_language_detector": self.config.enable_language_detector,
        }

        if self.config.silence_duration:
            additions["silence_duration"] = self.config.silence_duration

        if self.config.explicit_language:
            additions["explicit_language"] = self.config.explicit_language

        return {
            "user": {
                "uid": self.config.user_id,
            },
            "namespace": self.config.namespace,
            "req_params": {
                "speaker": speaker,
                "audio_params": audio_params,
                "additions": json.dumps(additions),
            },
        }

    async def _start_session(self, websocket, base_request: Dict[str, Any]) -> str:
        """Send StartSession on ``websocket`` and wait for SessionStarted.

        Returns:
            The generated session ID (a random UUID4 string).
        """
        session_id = str(uuid.uuid4())
        request = copy.deepcopy(base_request)
        request["event"] = EventType.StartSession
        await start_session(websocket, json.dumps(request).encode(), session_id)
        await wait_for_event(websocket, MsgType.FullServerResponse, EventType.SessionStarted)
        return session_id

    @staticmethod
    def _task_request_payload(base_request: Dict[str, Any], text: str) -> bytes:
        """Encode a TaskRequest for ``text``; ``base_request`` itself is not modified."""
        request = copy.deepcopy(base_request)
        request["event"] = EventType.TaskRequest
        request["req_params"]["text"] = text
        return json.dumps(request).encode()

    @staticmethod
    def _is_session_finished(msg, log_prefix: str) -> bool:
        """Classify a non-audio server message.

        Returns True for SessionFinished and False for messages that should be
        ignored.

        Raises:
            TTSSessionError: on SessionFailed.
            The exception built by ``create_tts_exception``: on an Error frame.
        """
        if msg.type == MsgType.FullServerResponse:
            if msg.event == EventType.SessionFinished:
                return True
            if msg.event == EventType.SessionFailed:
                failure = f"{log_prefix}会话失败: {msg.payload}"
                logger.error(failure)
                # Previously this only logged and ended the stream, so callers
                # received empty/partial audio with no error.
                raise TTSSessionError(failure)
            return False

        if msg.type == MsgType.Error:
            error_code = msg.error_code
            logger.error(f"{log_prefix}收到错误消息: {msg}")
            exception = create_tts_exception(error_code, msg.payload)
            # Quota errors come with an actionable hint.
            if error_code == TTSErrorCode.CLIENT_QUOTA_EXCEEDED:
                logger.warning(get_quota_solution(error_code))
            raise exception

        logger.debug(f"收到其他消息: {msg}")
        return False

    async def _end_session(self, send_task, finished: bool, failure_label: str) -> None:
        """Clean up after a session's receive loop exits (normally, by error, or by aclose()).

        When the session did not reach SessionFinished, the sender is cancelled
        (in bidi mode it may otherwise wait on the text queue forever) and the
        connection is dropped, because unread frames of the aborted session would
        otherwise be read by the next session on the reused socket.
        """
        if send_task is not None:
            if not finished and not send_task.done():
                send_task.cancel()
            # gather(..., return_exceptions=True) reports the sender's own
            # cancellation as a result instead of raising, while an outer
            # cancellation still propagates.
            (outcome,) = await asyncio.gather(send_task, return_exceptions=True)
            if isinstance(outcome, Exception):
                logger.warning(f"{failure_label}: {outcome}")

        if not finished:
            websocket, self.websocket = self.websocket, None
            if websocket is not None:
                await self._close_socket(websocket)

    async def synthesize_stream(self, text: str, speaker: Optional[str] = None) -> AsyncGenerator[bytes, None]:
        """Synthesize ``text`` in a single session, yielding audio chunks as they arrive.

        Args:
            text: Text sent as one TaskRequest, followed by FinishSession.
            speaker: Voice name; falls back to ``config.default_speaker`` when falsy.

        Yields:
            Non-empty payloads of AudioOnlyServer frames.

        Raises:
            TTSSessionError: the server reported SessionFailed.
            The exception built by ``create_tts_exception``: the server sent an Error frame.
        """
        if not speaker:
            speaker = self.config.default_speaker

        await self.connect()
        websocket = self.websocket

        send_task = None
        finished = False
        try:
            base_request = self._create_base_request(speaker)
            session_id = await self._start_session(websocket, base_request)
            logger.debug(f"会话 {session_id} 已启动")

            async def send_text():
                await task_request(websocket, self._task_request_payload(base_request, text), session_id)
                await finish_session(websocket, session_id)

            # Send concurrently so audio can be received while the request is written.
            send_task = asyncio.create_task(send_text())

            chunk_count = 0
            byte_count = 0
            while True:
                msg = await receive_message(websocket)
                if msg.type == MsgType.AudioOnlyServer:
                    if msg.payload:
                        chunk_count += 1
                        byte_count += len(msg.payload)
                        logger.debug("收到音频块 %d，长度: %d", chunk_count, len(msg.payload))
                        yield msg.payload
                elif self._is_session_finished(msg, ""):
                    finished = True
                    logger.debug(f"会话 {session_id} 已完成，共接收 {chunk_count} 个音频块，总字节: {byte_count}")
                    break
        except Exception as e:
            logger.error(f"语音合成失败: {e}")
            raise
        finally:
            await self._end_session(send_task, finished, "发送任务异常")

    async def synthesize_stream_bidi(
        self,
        text_queue: asyncio.Queue,
        speaker: Optional[str] = None
    ) -> AsyncGenerator[bytes, None]:
        """Bidirectional streaming: send texts from a queue while receiving audio.

        Args:
            text_queue: Texts to synthesize; each item becomes one TaskRequest.
                Putting ``None`` ends the session (FinishSession is sent).
            speaker: Voice name; falls back to ``config.default_speaker`` when falsy.

        Yields:
            Non-empty payloads of AudioOnlyServer frames.

        Raises:
            TTSSessionError: the server reported SessionFailed.
            The exception built by ``create_tts_exception``: the server sent an Error frame.
        """
        if not speaker:
            speaker = self.config.default_speaker

        await self.connect()
        websocket = self.websocket

        send_task = None
        finished = False
        try:
            base_request = self._create_base_request(speaker)
            session_id = await self._start_session(websocket, base_request)
            logger.debug(f"双向流式会话 {session_id} 已启动")

            async def send_texts():
                try:
                    while True:
                        text = await text_queue.get()
                        if text is None:  # end-of-input marker
                            await finish_session(websocket, session_id)
                            break
                        await task_request(websocket, self._task_request_payload(base_request, text), session_id)
                        # Small pause to avoid sending too fast.
                        await asyncio.sleep(0.005)
                except Exception as e:
                    logger.error(f"发送文本时出错: {e}")
                    raise

            send_task = asyncio.create_task(send_texts())

            while True:
                msg = await receive_message(websocket)
                if msg.type == MsgType.AudioOnlyServer:
                    if msg.payload:
                        yield msg.payload
                elif self._is_session_finished(msg, "双向流式"):
                    finished = True
                    logger.debug(f"双向流式会话 {session_id} 已完成")
                    break
        except Exception as e:
            logger.error(f"双向流式语音合成失败: {e}")
            raise
        finally:
            await self._end_session(send_task, finished, "双向流式发送任务异常")

    async def synthesize_text(self, text: str, speaker: Optional[str] = None) -> bytes:
        """Synthesize ``text`` and return all audio chunks concatenated into one bytes object."""
        audio_data = bytearray()

        async for chunk in self.synthesize_stream(text, speaker):
            audio_data.extend(chunk)

        return bytes(audio_data)
